import os
import numpy as np
import torch
from PIL import Image
import cv2
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from torchvision import transforms
from tqdm import tqdm


class MetricsEvaluator:
    """
    Evaluate image quality metrics between original, watermarked, and attacked images.

    Works with image directories rather than tensors: every image file in
    ``watermarked_dir`` is paired, by identical file name, with the file of the
    same name in ``original_dir`` and ``attacked_dir`` (when those are set).

    Two families of metrics are computed:

    * Reference-based: PSNR and SSIM (scikit-image), LPIPS (pyiqa, with a
      fallback to the ``lpips`` package) and any further names listed in
      ``refer_metrics`` (pyiqa), comparing the original image against the
      watermarked / attacked image.
    * No-reference: every name listed in ``non_refer_metrics`` (pyiqa),
      computed on each image on its own.
    """

    # Lower-case file extensions treated as images when scanning a directory.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    # Reference-based metric names computed with scikit-image instead of pyiqa.
    _SKIMAGE_METRICS = ('psnr', 'ssim')

    def __init__(self, original_dir=None, watermarked_dir=None, attacked_dir=None,
                 non_refer_metrics=None, refer_metrics=None):
        """
        Initialize the metrics evaluator.

        Args:
            original_dir: Directory containing original images (no watermark)
            watermarked_dir: Directory containing watermarked images
            attacked_dir: Directory containing attacked watermarked images
            non_refer_metrics: List of no-reference metric names; falls back to
                ['niqe', 'brisque'] when None or empty
            refer_metrics: List of reference-based metric names; falls back to
                ['lpips', 'psnr', 'ssim'] when None or empty
        """
        self.original_dir = original_dir
        self.watermarked_dir = watermarked_dir
        self.attacked_dir = attacked_dir
        self.to_tensor = transforms.ToTensor()

        # Default no-reference metrics
        self.non_refer_metrics = non_refer_metrics or ['niqe', 'brisque']

        # Default reference-based metrics
        self.refer_metrics = refer_metrics or ['lpips', 'psnr', 'ssim']

        # Lazily created `lpips` package model; only used when pyiqa's LPIPS
        # metric is unavailable (see calculate_lpips).
        self._lpips_fallback = None

        # Initialize pyiqa metrics
        self._init_pyiqa_metrics()

    def _init_pyiqa_metrics(self):
        """
        (Re)build the pyiqa metric models for the configured metric names.

        Populates ``self.non_refer_metric_models`` and ``self.refer_metric_models``
        as ``{metric_name: pyiqa metric}``. Names pyiqa fails to create are
        reported and skipped; PSNR/SSIM are skipped because they are computed
        with scikit-image. If pyiqa is not installed both dicts stay empty.
        """
        self.non_refer_metric_models = {}
        self.refer_metric_models = {}

        try:
            import pyiqa
        except ImportError:
            print("pyiqa not found. Install with 'pip install pyiqa'.")
            return

        # Initialize no-reference metrics
        for metric_name in self.non_refer_metrics:
            try:
                self.non_refer_metric_models[metric_name] = pyiqa.create_metric(metric_name)
                print(f"Successfully initialized no-reference metric: {metric_name}")
            except Exception as e:
                print(f"Failed to initialize no-reference metric {metric_name}: {e}")

        # Initialize reference-based metrics
        for metric_name in self.refer_metrics:
            if metric_name.lower() in self._SKIMAGE_METRICS:
                # These are handled by skimage, not pyiqa
                continue
            try:
                self.refer_metric_models[metric_name] = pyiqa.create_metric(metric_name)
                print(f"Successfully initialized reference-based metric: {metric_name}")
            except Exception as e:
                print(f"Failed to initialize reference-based metric {metric_name}: {e}")

    def set_directories(self, original_dir=None, watermarked_dir=None, attacked_dir=None):
        """
        Update the directories for evaluation.

        Only truthy arguments replace the current value; None/empty leaves it unchanged.

        Args:
            original_dir: Directory containing original images (no watermark)
            watermarked_dir: Directory containing watermarked images
            attacked_dir: Directory containing attacked watermarked images
        """
        if original_dir:
            self.original_dir = original_dir
        if watermarked_dir:
            self.watermarked_dir = watermarked_dir
        if attacked_dir:
            self.attacked_dir = attacked_dir

    def set_metrics(self, non_refer_metrics=None, refer_metrics=None):
        """
        Update the metrics to be used for evaluation and rebuild the pyiqa models.

        Only non-empty lists replace the current configuration.

        Args:
            non_refer_metrics: List of no-reference metrics to use
            refer_metrics: List of reference-based metrics to use
        """
        if non_refer_metrics:
            self.non_refer_metrics = non_refer_metrics
        if refer_metrics:
            self.refer_metrics = refer_metrics

        # Reinitialize pyiqa metrics
        self._init_pyiqa_metrics()

    def load_image(self, img_path):
        """
        Load an image from path as an RGB tensor in [0, 1].

        Args:
            img_path: Path to the image file

        Returns:
            (C, H, W) float tensor in [0, 1], or None if the file could not be
            read (the error is printed).
        """
        try:
            # Context manager closes the underlying file handle promptly.
            with Image.open(img_path) as img:
                return self.to_tensor(img.convert('RGB'))  # [0,1] range
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            return None

    @staticmethod
    def _check_same_shape(img1, img2):
        """Raise ValueError unless both images have identical shapes."""
        if img1.shape != img2.shape:
            raise ValueError(f"Images must have the same dimensions. Got {img1.shape} and {img2.shape}")

    @staticmethod
    def _to_hwc_255(img):
        """Convert a (C, H, W) tensor in [0, 1] to an (H, W, C) numpy array in [0, 255]."""
        return img.detach().cpu().numpy().transpose(1, 2, 0) * 255.0

    def calculate_psnr(self, img1, img2):
        """
        Calculate PSNR between two images.

        Args:
            img1: First image tensor (C, H, W) in [0, 1]
            img2: Second image tensor (C, H, W) in [0, 1]

        Returns:
            PSNR value in dB (inf for identical images)

        Raises:
            ValueError: If the images have different shapes.
        """
        self._check_same_shape(img1, img2)
        # data_range must be explicit: the arrays are floats in [0, 255] and
        # skimage would otherwise assume the float range [-1, 1] and raise.
        return psnr(self._to_hwc_255(img1), self._to_hwc_255(img2), data_range=255)

    def calculate_ssim(self, img1, img2):
        """
        Calculate SSIM between two images (RGB channels along the last axis).

        Args:
            img1: First image tensor (C, H, W) in [0, 1]
            img2: Second image tensor (C, H, W) in [0, 1]

        Returns:
            SSIM value

        Raises:
            ValueError: If the images have different shapes.
        """
        self._check_same_shape(img1, img2)
        return ssim(self._to_hwc_255(img1), self._to_hwc_255(img2), channel_axis=2, data_range=255)

    def calculate_lpips(self, img1, img2):
        """
        Calculate LPIPS (Learned Perceptual Image Patch Similarity) between two images.

        Uses the pyiqa LPIPS model when one was initialized; otherwise falls back
        to the ``lpips`` package (AlexNet, v0.1), whose model is created once and
        cached on the instance.

        Args:
            img1: First image tensor in [0, 1], (C, H, W) or batched (N, C, H, W)
            img2: Second image tensor in [0, 1], same layout as img1

        Returns:
            LPIPS value, or 0.0 if no LPIPS implementation is available.
        """
        # Add batch dimension if needed
        if img1.dim() == 3:
            img1 = img1.unsqueeze(0)
        if img2.dim() == 3:
            img2 = img2.unsqueeze(0)

        pyiqa_lpips = next(
            (model for name, model in self.refer_metric_models.items() if name.lower() == 'lpips'),
            None,
        )
        if pyiqa_lpips is not None:
            with torch.no_grad():
                return pyiqa_lpips(img1, img2).item()

        loss_fn = getattr(self, '_lpips_fallback', None)
        if loss_fn is None:
            try:
                import lpips
                loss_fn = lpips.LPIPS(net='alex', version='0.1')
            except ImportError:
                print("LPIPS not available. Install with 'pip install lpips' or 'pip install pyiqa'.")
                return 0.0
            # Building the network is expensive; reuse it for later calls.
            self._lpips_fallback = loss_fn

        with torch.no_grad():
            # The lpips package expects inputs in [-1, 1].
            return loss_fn(img1 * 2 - 1, img2 * 2 - 1).item()

    def calculate_non_reference_metrics(self, img):
        """
        Calculate no-reference image quality metrics using pyiqa.

        Args:
            img: Image tensor in [0, 1], (C, H, W) or batched (N, C, H, W)

        Returns:
            Dictionary {metric_name: value}. A metric that raises is reported and
            recorded as 0.0.
        """
        non_ref_metrics = {}

        if not self.non_refer_metric_models:
            return non_ref_metrics

        img_batch = img.unsqueeze(0) if img.dim() == 3 else img

        for metric_name, model in self.non_refer_metric_models.items():
            try:
                with torch.no_grad():
                    non_ref_metrics[metric_name] = model(img_batch).item()
            except Exception as e:
                print(f"Error calculating {metric_name}: {e}")
                non_ref_metrics[metric_name] = 0.0

        return non_ref_metrics

    def calculate_reference_metrics(self, ref_img, test_img, exclude=()):
        """
        Calculate reference-based image quality metrics using pyiqa.

        Args:
            ref_img: Reference image tensor in [0, 1], (C, H, W) or (N, C, H, W)
            test_img: Test (distorted) image tensor in [0, 1], same layout
            exclude: Metric names (case-insensitive) to skip

        Returns:
            Dictionary {metric_name: value}. A metric that raises is reported and
            recorded as 0.0.
        """
        ref_metrics = {}

        if not self.refer_metric_models:
            return ref_metrics

        excluded = {name.lower() for name in exclude}
        ref_img_batch = ref_img.unsqueeze(0) if ref_img.dim() == 3 else ref_img
        test_img_batch = test_img.unsqueeze(0) if test_img.dim() == 3 else test_img

        for metric_name, model in self.refer_metric_models.items():
            if metric_name.lower() in excluded:
                continue
            try:
                with torch.no_grad():
                    # pyiqa full-reference metrics are called as
                    # (distorted, reference); order matters for asymmetric ones.
                    ref_metrics[metric_name] = model(test_img_batch, ref_img_batch).item()
            except Exception as e:
                print(f"Error calculating {metric_name}: {e}")
                ref_metrics[metric_name] = 0.0

        return ref_metrics

    def _list_images(self, directory):
        """Return the sorted names of image files (by extension) directly inside `directory`."""
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        )

    def _load_optional(self, directory, img_name):
        """Load `directory/img_name` if the directory is set and the file exists; else None."""
        if not directory:
            return None
        path = os.path.join(directory, img_name)
        if not os.path.exists(path):
            return None
        return self.load_image(path)

    @staticmethod
    def _match_size(img, like):
        """Resize `img` to the spatial size of `like` when their shapes differ."""
        if img.shape == like.shape:
            return img
        return transforms.Resize((like.shape[1], like.shape[2]))(img)

    def _pairwise_metrics(self, ref_img, test_img):
        """Compute PSNR, SSIM, LPIPS and the other pyiqa reference metrics for one image pair."""
        values = {
            'psnr': self.calculate_psnr(ref_img, test_img),
            'ssim': self.calculate_ssim(ref_img, test_img),
            'lpips': self.calculate_lpips(ref_img, test_img),
        }
        # LPIPS was already computed above; don't run the model a second time.
        values.update(self.calculate_reference_metrics(ref_img, test_img, exclude=('lpips',)))
        return values

    def evaluate_all_metrics(self, original_to_watermarked=True, original_to_attacked=True):
        """
        Evaluate all metrics for every image file in ``watermarked_dir``.

        For each watermarked image, the same-named files in ``original_dir`` and
        ``attacked_dir`` are used when those directories are set and the files
        exist; missing or unreadable files are skipped instead of aborting the run.
        No-reference metrics are computed for the watermarked image, for the
        attacked image (whenever one is found), and for the original image (only
        when ``original_to_watermarked`` is set, after resizing it to the
        watermarked image's size).

        Args:
            original_to_watermarked: Whether to evaluate metrics between original and watermarked images
            original_to_attacked: Whether to evaluate metrics between original and attacked images

        Returns:
            Dictionary mapping each metric key (e.g. ``'psnr_orig_wm'``,
            ``'niqe_attacked'``) to its mean over the evaluated images, plus a
            ``'<key>_std'`` entry with the standard deviation. Keys without any
            samples are reported as 0.0.

        Raises:
            ValueError: If ``watermarked_dir`` is not set.
        """
        if not self.watermarked_dir:
            raise ValueError("Watermarked directory must be specified")

        watermarked_images = self._list_images(self.watermarked_dir)

        metrics = {
            'psnr_orig_wm': [],
            'ssim_orig_wm': [],
            'lpips_orig_wm': [],
            'psnr_orig_att': [],
            'ssim_orig_att': [],
            'lpips_orig_att': []
        }

        # Add reference-based metrics from pyiqa
        for metric in self.refer_metric_models:
            if metric.lower() != 'lpips':  # lpips is already included
                metrics[f'{metric}_orig_wm'] = []
                metrics[f'{metric}_orig_att'] = []

        # Add no-reference metrics from pyiqa
        for metric in self.non_refer_metric_models:
            metrics[f'{metric}_original'] = []
            metrics[f'{metric}_watermarked'] = []
            metrics[f'{metric}_attacked'] = []

        def record(values, suffix):
            # Append every {name: value} entry under the key "<name>_<suffix>".
            for name, value in values.items():
                metrics.setdefault(f'{name}_{suffix}', []).append(value)

        for img_name in tqdm(watermarked_images, desc="Evaluating metrics"):
            wm_img = self.load_image(os.path.join(self.watermarked_dir, img_name))
            if wm_img is None:
                # load_image already reported the error; nothing to evaluate.
                continue
            record(self.calculate_non_reference_metrics(wm_img), 'watermarked')

            att_img = self._load_optional(self.attacked_dir, img_name)
            if att_img is not None:
                record(self.calculate_non_reference_metrics(att_img), 'attacked')

            # Load the original once and reuse it for both comparisons.
            orig_img = None
            if original_to_watermarked or (original_to_attacked and att_img is not None):
                orig_img = self._load_optional(self.original_dir, img_name)
            if orig_img is None:
                continue

            if original_to_watermarked:
                orig_for_wm = self._match_size(orig_img, wm_img)
                record(self.calculate_non_reference_metrics(orig_for_wm), 'original')
                record(self._pairwise_metrics(orig_for_wm, wm_img), 'orig_wm')

            if original_to_attacked and att_img is not None:
                orig_for_att = self._match_size(orig_img, att_img)
                record(self._pairwise_metrics(orig_for_att, att_img), 'orig_att')

        # Calculate average metrics
        result_metrics = {}
        for key, values in metrics.items():
            if values:
                result_metrics[key] = np.mean(values)
                result_metrics[f"{key}_std"] = np.std(values)
            else:
                result_metrics[key] = 0.0
                result_metrics[f"{key}_std"] = 0.0

        return result_metrics

    def evaluate_image_metrics(self, img_name):
        """
        Evaluate metrics for a single image file name.

        Args:
            img_name: Name of the image file (looked up in each configured directory)

        Returns:
            Dictionary of metrics for the image; empty if the watermarked file is
            missing or unreadable. Attacked no-reference metrics are included
            whenever the attacked file exists; comparisons with the original are
            included when the original file exists.

        Raises:
            ValueError: If ``watermarked_dir`` is not set.
        """
        if not self.watermarked_dir:
            raise ValueError("Watermarked directory must be specified")

        metrics = {}

        def record(values, suffix):
            # Store every {name: value} entry under the key "<name>_<suffix>".
            for name, value in values.items():
                metrics[f'{name}_{suffix}'] = value

        wm_img = self._load_optional(self.watermarked_dir, img_name)
        if wm_img is None:
            return metrics
        record(self.calculate_non_reference_metrics(wm_img), 'watermarked')

        orig_img = self._load_optional(self.original_dir, img_name)
        if orig_img is not None:
            orig_for_wm = self._match_size(orig_img, wm_img)
            record(self.calculate_non_reference_metrics(orig_for_wm), 'original')
            record(self._pairwise_metrics(orig_for_wm, wm_img), 'orig_wm')

        att_img = self._load_optional(self.attacked_dir, img_name)
        if att_img is not None:
            record(self.calculate_non_reference_metrics(att_img), 'attacked')
            if orig_img is not None:
                orig_for_att = self._match_size(orig_img, att_img)
                record(self._pairwise_metrics(orig_for_att, att_img), 'orig_att')

        return metrics